{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-02-09T15:28:03.083391Z",
     "iopub.status.busy": "2025-02-09T15:28:03.082960Z",
     "iopub.status.idle": "2025-02-09T15:28:12.294769Z",
     "shell.execute_reply": "2025-02-09T15:28:12.293484Z",
     "shell.execute_reply.started": "2025-02-09T15:28:03.083356Z"
    },
    "trusted": true
   },
   "outputs": [],
   "source": [
    "%%bash\n",
    "\n",
    "sudo rm -rf torchtitan\n",
    "git clone -q https://github.com/pytorch/torchtitan\n",
    "git -C torchtitan checkout -q 49c6d6fc15ef644e5c3b1003ad4e0d9ea5fcb9a9\n",
    "curl -s https://gist.githubusercontent.com/antony-frolov/c2e69bbda2b4418b1ab1c99839c55877/raw/c873709f6fe34dbf8ba678302e4fa92d6ed8c7f1/1b.patch -o 1b.patch\n",
    "patch -s -p1 -i ../1b.patch -d torchtitan\n",
    "sudo pip install -q fire triton -r ./torchtitan/requirements.txt ./torchtitan\n",
    "sudo apt-get update -qq && sudo apt-get install -qq pciutils"
   ]
  },
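  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Environment sanity check (a minimal sketch): train.py below requires CUDA,\n",
    "# and the torchrun command at the end assumes 2 visible GPUs.\n",
    "import torch\n",
    "\n",
    "print(\"torch\", torch.__version__)\n",
    "print(\"cuda available:\", torch.cuda.is_available())\n",
    "print(\"device count:\", torch.cuda.device_count())"
   ]
  },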
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-02-09T16:08:10.559856Z",
     "iopub.status.busy": "2025-02-09T16:08:10.559432Z",
     "iopub.status.idle": "2025-02-09T16:08:10.571441Z",
     "shell.execute_reply": "2025-02-09T16:08:10.570076Z",
     "shell.execute_reply.started": "2025-02-09T16:08:10.559821Z"
    },
    "trusted": true
   },
   "outputs": [],
   "source": [
    "%%writefile train.py\n",
    "import functools\n",
    "import os\n",
    "import pickle\n",
    "import time\n",
    "from typing import Optional\n",
    "\n",
    "import fire\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.distributed import DeviceMesh, init_device_mesh\n",
    "from torch.distributed._composable.fsdp import (\n",
    "    CPUOffloadPolicy,\n",
    "    MixedPrecisionPolicy,\n",
    "    fully_shard,\n",
    ")\n",
    "from torch.optim.lr_scheduler import LambdaLR\n",
    "\n",
    "import torchtitan.utils as utils\n",
    "from torchtitan.datasets import build_hf_data_loader, build_tokenizer\n",
    "from torchtitan.logging import init_logger, logger\n",
    "from torchtitan.metrics import build_device_memory_monitor\n",
    "from torchtitan.models import model_name_to_cls, model_name_to_tokenizer, models_config\n",
    "from torchtitan.optimizer import linear_warmup_linear_decay\n",
    "\n",
    "\n",
    "def trace_handler(prof, trace_dir: str):\n",
    "    curr_trace_dir_name = \"iteration_\" + str(prof.step_num)\n",
    "    curr_trace_dir = os.path.join(trace_dir, curr_trace_dir_name)\n",
    "    if not os.path.exists(curr_trace_dir):\n",
    "        os.makedirs(curr_trace_dir, exist_ok=True)\n",
    "\n",
    "    logger.info(f\"Dumping profiler traces at step {prof.step_num}\")\n",
    "    begin = time.monotonic()\n",
    "    prof.export_chrome_trace(\n",
    "        f\"{curr_trace_dir}/rank{torch.distributed.get_rank()}_trace.json\"\n",
    "    )\n",
    "    logger.info(\n",
    "        f\"Finished dumping profiler traces in {time.monotonic() - begin:.2f} seconds\"\n",
    "    )\n",
    "\n",
    "\n",
    "class MemoryProfiler:\n",
    "    def __init__(\n",
    "        self,\n",
    "        step_num: int,\n",
    "        freq: int,\n",
    "        snapshot_dir: str,\n",
    "        dir_name: Optional[str] = None,\n",
    "    ):\n",
    "        self.snapshot_dir = snapshot_dir\n",
    "        if not os.path.exists(snapshot_dir):\n",
    "            os.makedirs(snapshot_dir, exist_ok=True)\n",
    "\n",
    "        # when resume training, we start from the last step\n",
    "        self.step_num = step_num\n",
    "        self.freq = freq\n",
    "\n",
    "        self.dir_name = dir_name\n",
    "\n",
    "    def step(self):\n",
    "        self.step_num += 1\n",
    "        if self.step_num % self.freq not in [0, self.freq - 1]:\n",
    "            return\n",
    "        if self.step_num % self.freq == self.freq - 1:\n",
    "            torch.cuda.memory._record_memory_history()\n",
    "            return\n",
    "        curr_step = self.step_num\n",
    "        if self.dir_name is None:\n",
    "            dir_name = f\"iteration_{curr_step}\"\n",
    "        else:\n",
    "            dir_name = self.dir_name\n",
    "        curr_snapshot_dir = os.path.join(self.snapshot_dir, dir_name)\n",
    "        if not os.path.exists(curr_snapshot_dir):\n",
    "            os.makedirs(curr_snapshot_dir, exist_ok=True)\n",
    "        logger.info(f\"Dumping memory snapshot at step {curr_step}\")\n",
    "        begin = time.monotonic()\n",
    "        with open(\n",
    "            f\"{curr_snapshot_dir}/rank{torch.distributed.get_rank()}_memory_snapshot.pickle\",\n",
    "            \"wb\",\n",
    "        ) as output:\n",
    "            pickle.dump(torch.cuda.memory._snapshot(), output)\n",
    "        torch.cuda.memory._record_memory_history(None)\n",
    "        logger.info(\n",
    "            f\"Finished dumping memory snapshot in {time.monotonic() - begin:.2f} seconds\"\n",
    "        )\n",
    "\n",
    "\n",
    "def apply_fsdp(\n",
    "    model: nn.Module,\n",
    "    dp_mesh: DeviceMesh,\n",
    "    param_dtype: torch.dtype,\n",
    "    reduce_dtype: torch.dtype,\n",
    "    cpu_offload: bool,\n",
    "    reshard_after_forward: bool,\n",
    "):\n",
    "    mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype)\n",
    "    fsdp_config = {\"mesh\": dp_mesh, \"mp_policy\": mp_policy}\n",
    "    if cpu_offload:\n",
    "        fsdp_config[\"offload_policy\"] = CPUOffloadPolicy()\n",
    "\n",
    "    for layer_id, transformer_block in model.layers.items():\n",
    "        fully_shard(\n",
    "            transformer_block,\n",
    "            **fsdp_config,\n",
    "            reshard_after_forward=reshard_after_forward,\n",
    "        )\n",
    "    fully_shard(model, **fsdp_config)\n",
    "\n",
    "\n",
    "def train(\n",
    "    lr: float = 8e-4,\n",
    "    max_norm: float = 1.0,\n",
    "    training_steps: int = 10,\n",
    "    warmup_steps: int = 2,\n",
    "    batch_size: int = 8,\n",
    "    seq_len: int = 2048,\n",
    "    model_name: str = \"llama3\",\n",
    "    flavor: str = \"debugmodel\",\n",
    "    norm_type: str = \"rmsnorm\",\n",
    "    enable_cpu_offload: bool = False,\n",
    "    param_dtype: str = \"float32\",\n",
    "    reduce_dtype: str = \"float32\",\n",
    "    reshard_after_forward: bool = True,\n",
    "    reshard_after_forward_degree: int | None = None,\n",
    "    device_type: str = \"cuda\",\n",
    "    log_freq: int = 1,\n",
    "    gc_freq: int = 50,\n",
    "    profile_freq: int = 10,\n",
    "    profile_active: int = 1,\n",
    "    profile_warmup: int = 3,\n",
    "    dump_folder: str = \".\",\n",
    "    save_traces_folder: str = \"profile_trace\",\n",
    "    save_memory_snapshot_folder: str = \"memory_snapshot\",\n",
    "    apply_compile: bool = False,\n",
    "    num_gas_steps: int = 1,\n",
    "    reshard_after_backward: bool = True,\n",
    "    reduce_grads: bool = True,\n",
    "):\n",
    "    decay_steps = training_steps - warmup_steps\n",
    "    param_dtype = getattr(torch, param_dtype)\n",
    "    reduce_dtype = getattr(torch, reduce_dtype)\n",
    "    if reshard_after_forward_degree is not None:\n",
    "        assert reshard_after_forward\n",
    "        reshard_after_forward = reshard_after_forward_degree\n",
    "\n",
    "    init_logger()\n",
    "\n",
    "    # take control of garbage collection to avoid stragglers\n",
    "    gc_handler = utils.GarbageCollection(gc_freq=gc_freq)\n",
    "\n",
    "    # init distributed\n",
    "    world_size = int(os.environ[\"WORLD_SIZE\"])\n",
    "    device = torch.device(f\"{device_type}:{int(os.environ['LOCAL_RANK'])}\")\n",
    "    torch.cuda.set_device(device)\n",
    "    if not torch.distributed.is_initialized():\n",
    "        torch.distributed.init_process_group(\"cuda:nccl,cpu:gloo\")\n",
    "    # initialize device memory monitor and get peak flops for MFU calculation\n",
    "    device_memory_monitor = build_device_memory_monitor()\n",
    "    gpu_peak_flops = utils.get_peak_flops(device_memory_monitor.device_name)\n",
    "    logger.info(f\"Peak FLOPS used for computing MFU: {gpu_peak_flops:.3e}\")\n",
    "\n",
    "    # build meshes\n",
    "    world_mesh = init_device_mesh(device_type, (world_size,), mesh_dim_names=(\"dp\",))\n",
    "    dp_mesh = world_mesh[\"dp\"]\n",
    "    dp_degree, dp_rank = dp_mesh.size(), dp_mesh.get_local_rank()\n",
    "\n",
    "    # build tokenizer\n",
    "    tokenizer_type = model_name_to_tokenizer[model_name]\n",
    "    tokenizer = build_tokenizer(\n",
    "        tokenizer_type, \"torchtitan/tests/assets/test_tiktoken.model\"\n",
    "    )\n",
    "    # build dataloader\n",
    "    data_loader = build_hf_data_loader(\n",
    "        \"c4_test\",\n",
    "        \"torchtitan/tests/assets/c4_test\",\n",
    "        tokenizer,\n",
    "        batch_size=batch_size,\n",
    "        seq_len=seq_len,\n",
    "        world_size=dp_degree,\n",
    "        rank=dp_rank,\n",
    "    )\n",
    "\n",
    "    # build model (using meta init)\n",
    "    model_cls = model_name_to_cls[model_name]\n",
    "    model_config = models_config[model_name][flavor]\n",
    "    model_config.norm_type = norm_type\n",
    "    model_config.vocab_size = tokenizer.n_words\n",
    "    model_config.max_seq_len = seq_len\n",
    "\n",
    "    logger.info(f\"Building {model_name} {flavor} with {model_config}\")\n",
    "    memory_profiler = MemoryProfiler(\n",
    "        profile_freq - 2,\n",
    "        profile_freq,\n",
    "        snapshot_dir=os.path.join(dump_folder, save_memory_snapshot_folder),\n",
    "        dir_name=\"model_init\",\n",
    "    )\n",
    "    memory_profiler.step()\n",
    "    with torch.device(\"meta\"):\n",
    "        model = model_cls.from_model_args(model_config)\n",
    "\n",
    "    # log model size\n",
    "    model_param_count = utils.get_num_params(model)\n",
    "    num_flop_per_token = utils.get_num_flop_per_token(\n",
    "        utils.get_num_params(model, exclude_embedding=True),\n",
    "        model_config,\n",
    "        seq_len,\n",
    "    )\n",
    "    logger.info(\n",
    "        f\"Model {model_name} {flavor} \" f\"size: {model_param_count:,} total parameters\"\n",
    "    )\n",
    "\n",
    "    # loss function\n",
    "    def loss_fn(pred, labels):\n",
    "        return torch.nn.functional.cross_entropy(\n",
    "            pred.flatten(0, 1).float(), labels.flatten(0, 1)\n",
    "        )\n",
    "\n",
    "    # move sharded model to CPU/GPU and initialize weights via DTensor\n",
    "    if enable_cpu_offload:\n",
    "        init_device = \"cpu\"\n",
    "        buffer_device = device_type\n",
    "    else:\n",
    "        init_device = device_type\n",
    "        buffer_device = None\n",
    "\n",
    "    # apply parallelisms and initialization\n",
    "    if apply_compile:\n",
    "        for layer_id, transformer_block in model.layers.named_children():\n",
    "            transformer_block = torch.compile(transformer_block, fullgraph=True)\n",
    "            model.layers.register_module(layer_id, transformer_block)\n",
    "        logger.info(\"Compiling each TransformerBlock with torch.compile\")\n",
    "    apply_fsdp(\n",
    "        model,\n",
    "        dp_mesh=dp_mesh,\n",
    "        param_dtype=param_dtype,\n",
    "        reduce_dtype=reduce_dtype,\n",
    "        cpu_offload=enable_cpu_offload,\n",
    "        reshard_after_forward=reshard_after_forward,\n",
    "    )\n",
    "    model.to_empty(device=init_device)\n",
    "    with torch.no_grad():\n",
    "        model.init_weights(buffer_device=buffer_device)\n",
    "    model.train()\n",
    "\n",
    "    memory_profiler.step()\n",
    "\n",
    "    device_mem_stats = device_memory_monitor.get_peak_stats()\n",
    "    logger.info(\n",
    "        f\"{device_type.upper()} memory usage for model: \"\n",
    "        f\"{device_mem_stats.max_reserved_gib:.2f}GiB\"\n",
    "        f\"({device_mem_stats.max_reserved_pct:.2f}%)\"\n",
    "    )\n",
    "\n",
    "    optimizer = torch.optim.AdamW(\n",
    "        model.parameters(),\n",
    "        lr=lr,\n",
    "        betas=(0.9, 0.95),\n",
    "        weight_decay=0.1,\n",
    "        fused=True,\n",
    "    )\n",
    "    lr_scheduler = LambdaLR(\n",
    "        optimizer,\n",
    "        lr_lambda=functools.partial(\n",
    "            linear_warmup_linear_decay, warmup_steps, decay_steps\n",
    "        ),\n",
    "    )\n",
    "\n",
    "    data_iterator = iter(data_loader)\n",
    "\n",
    "    train_context = utils.get_train_context(\n",
    "        enable_loss_parallel=False,\n",
    "        enable_compiled_autograd=False,\n",
    "    )\n",
    "\n",
    "    # variables used to keep info for metrics logging\n",
    "    step = 0\n",
    "    ntokens_since_last_log = 0\n",
    "    data_loading_times = []\n",
    "    time_last_log = time.perf_counter()\n",
    "    device_memory_monitor.reset_peak_stats()\n",
    "\n",
    "    # train loop\n",
    "    logger.info(\n",
    "        f\"Training starts at step {step + 1}, \"\n",
    "        f\"with local batch size {batch_size}, \"\n",
    "        f\"global batch size {batch_size * dp_degree}, \"\n",
    "        f\"sequence length {seq_len}, \"\n",
    "        f\"total steps {training_steps} \"\n",
    "        f\"(warmup {warmup_steps})\"\n",
    "    )\n",
    "    with torch.profiler.profile(\n",
    "        activities=[\n",
    "            torch.profiler.ProfilerActivity.CPU,\n",
    "            torch.profiler.ProfilerActivity.CUDA,\n",
    "        ],\n",
    "        schedule=torch.profiler.schedule(\n",
    "            wait=profile_freq - (profile_active + profile_warmup),\n",
    "            warmup=profile_warmup,\n",
    "            active=profile_active,\n",
    "        ),\n",
    "        on_trace_ready=functools.partial(\n",
    "            trace_handler, trace_dir=os.path.join(dump_folder, save_traces_folder)\n",
    "        ),\n",
    "        record_shapes=True,\n",
    "    ) as torch_profiler:\n",
    "        while step < training_steps:\n",
    "            memory_profiler = MemoryProfiler(\n",
    "                step,\n",
    "                profile_freq,\n",
    "                snapshot_dir=os.path.join(dump_folder, save_memory_snapshot_folder),\n",
    "            )\n",
    "\n",
    "            step += 1\n",
    "            gc_handler.run(step)\n",
    "\n",
    "            optimizer.zero_grad()\n",
    "\n",
    "            for gas_step in range(num_gas_steps):\n",
    "                is_last_backward = gas_step == num_gas_steps - 1\n",
    "                model.set_is_last_backward(is_last_backward)\n",
    "                model.set_reshard_after_backward(\n",
    "                    reshard_after_backward or is_last_backward\n",
    "                )\n",
    "                model.set_requires_gradient_sync(reduce_grads or is_last_backward)\n",
    "\n",
    "                # get batch\n",
    "                data_load_start = time.perf_counter()\n",
    "                batch = next(data_iterator)\n",
    "                input_ids, labels = batch\n",
    "                ntokens_since_last_log += labels.numel()\n",
    "                data_loading_times.append(time.perf_counter() - data_load_start)\n",
    "\n",
    "                input_ids = input_ids.to(device_type)\n",
    "                labels = labels.to(device_type)\n",
    "\n",
    "                # Non-PP forward / backward\n",
    "                with train_context():\n",
    "                    pred = model(input_ids)\n",
    "                    loss = loss_fn(pred, labels)\n",
    "                    # pred.shape=(bs, seq_len, vocab_size)\n",
    "                    # need to free to before bwd to avoid peaking memory\n",
    "                    del pred\n",
    "                    loss.backward()\n",
    "\n",
    "            # clip gradients\n",
    "            torch.nn.utils.clip_grad_norm_([p for p in model.parameters()], max_norm)\n",
    "\n",
    "            # optimizer step\n",
    "            optimizer.step()\n",
    "            lr_scheduler.step()\n",
    "\n",
    "            # log metrics\n",
    "            if step == 1 or step % log_freq == 0:\n",
    "                loss = loss.detach()\n",
    "                global_avg_loss = utils.dist_mean(loss, dp_mesh)\n",
    "\n",
    "                time_delta = time.perf_counter() - time_last_log\n",
    "\n",
    "                # tokens per second per device, abbreviated as tps\n",
    "                tps = ntokens_since_last_log / time_delta\n",
    "                # model FLOPS utilization\n",
    "                # For its definition and calculation, please refer to the PaLM paper:\n",
    "                # https://arxiv.org/abs/2204.02311\n",
    "                mfu = 100 * num_flop_per_token * tps / gpu_peak_flops\n",
    "\n",
    "                device_mem_stats = device_memory_monitor.get_peak_stats()\n",
    "\n",
    "                logger.info(\n",
    "                    f\"step: {step:2}  \"\n",
    "                    f\"loss: {global_avg_loss:7.4f}  \"\n",
    "                    f\"memory: {device_mem_stats.max_reserved_gib:5.2f}GiB\"\n",
    "                    f\"({device_mem_stats.max_reserved_pct:.2f}%)  \"\n",
    "                    f\"tps: {round(tps):,}  \"\n",
    "                    f\"mfu: {mfu:.2f}%\"\n",
    "                )\n",
    "\n",
    "                ntokens_since_last_log = 0\n",
    "                data_loading_times.clear()\n",
    "                time_last_log = time.perf_counter()\n",
    "                device_memory_monitor.reset_peak_stats()\n",
    "\n",
    "            # signal the profiler that the next profiling step has started\n",
    "            if torch_profiler:\n",
    "                torch_profiler.step()\n",
    "            if memory_profiler:\n",
    "                memory_profiler.step()\n",
    "\n",
    "    logger.info(\"Training completed\")\n",
    "\n",
    "    torch.distributed.destroy_process_group()\n",
    "\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    fire.Fire(train)\n"
   ]
  },
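  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# A minimal sketch of the LR schedule train.py wires up, assuming the\n",
    "# torchtitan checkout installed above is importable. It prints the multiplier\n",
    "# LambdaLR applies for the 20-step, 5-warmup-step run launched below.\n",
    "from torchtitan.optimizer import linear_warmup_linear_decay\n",
    "\n",
    "warmup_steps, training_steps = 5, 20\n",
    "decay_steps = training_steps - warmup_steps\n",
    "for step in range(training_steps):\n",
    "    factor = linear_warmup_linear_decay(warmup_steps, decay_steps, step)\n",
    "    print(f\"step {step:2}  lr multiplier {factor:.3f}\")"
   ]
  },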
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!OMP_NUM_THREADS=1 \\\n",
    "    torchrun \\\n",
    "    --local-ranks-filter 0 \\\n",
    "    --nproc-per-node 2 \\\n",
    "    train.py \\\n",
    "        --flavor 1B \\\n",
    "        --batch-size 2 \\\n",
    "        --seq-len 1024 \\\n",
    "        --training-steps 20 \\\n",
    "        --warmup-steps 5 \\\n",
    "        --gc-freq 5 \\\n",
    "        --profile-freq 10 \\\n",
    "        \\\n",
    "        --param-dtype float16 \\\n",
    "        --reduce-dtype float16"
   ]
  },
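  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# A minimal sketch for inspecting the artifacts the run above dumps, assuming\n",
    "# train.py's defaults (dump_folder=\".\"). Chrome traces land in\n",
    "# profile_trace/iteration_*/ and open in https://ui.perfetto.dev; each memory\n",
    "# snapshot .pickle can be explored at https://pytorch.org/memory_viz.\n",
    "import glob\n",
    "import pickle\n",
    "\n",
    "for path in sorted(glob.glob(\"memory_snapshot/*/rank0_memory_snapshot.pickle\")):\n",
    "    with open(path, \"rb\") as f:\n",
    "        snapshot = pickle.load(f)\n",
    "    # _snapshot() returns a dict whose \"segments\" list describes each\n",
    "    # cudaMalloc'd segment, including its total_size in bytes\n",
    "    total = sum(seg[\"total_size\"] for seg in snapshot[\"segments\"])\n",
    "    print(f\"{path}: {len(snapshot['segments'])} segments, {total / 2**30:.2f} GiB\")"
   ]
  }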
 ],
 "metadata": {
  "kaggle": {
   "accelerator": "none",
   "dataSources": [],
   "dockerImageVersionId": 30887,
   "isGpuEnabled": false,
   "isInternetEnabled": true,
   "language": "python",
   "sourceType": "notebook"
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
